{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c3998d94-1802-493f-9a0c-92fe27f3550b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练unet模型\n",
    "# 1.搭建unet模型\n",
    "# 2.自定义loss 函数\n",
    "# 3.开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c07a612c-7b80-4892-91b6-2339b3dddbdb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4558c3dd-0b29-4810-b566-23b05985f02c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 仍然是加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca4c93f5-52ba-4512-9c41-926bea5b9b7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\conda\\envs\\unet_env\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from imgaug.augmentables.segmaps import SegmentationMapsOnImage\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ef35425a-5971-4729-9163-ab2b5a1d38f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SegmentDataset(Dataset):\n",
    "\n",
    "    def __init__(self,where='train',seq=None):\n",
    "        # 获取数据\n",
    "        self.img_list = glob.glob('processed/{}/*/img_*'.format(where))\n",
    "        self.mask_list = glob.glob('processed/{}/*/img_*')\n",
    "        # 数据增强pipeline\n",
    "        self.seq = seq\n",
    "\n",
    "    def __len__(self):\n",
    "        # 返回数据大小\n",
    "        return len(self.img_list)\n",
    "    \n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # 获取具体每一个数据\n",
    "        \n",
    "        # 获取图片\n",
    "        img_file = self.img_list[idx]\n",
    "        mask_file = img_file.replace('img','label')\n",
    "        img = np.load(img_file)\n",
    "        # 获取mask\n",
    "        mask = np.load(mask_file)\n",
    "        \n",
    "        # 如果需要数据增强\n",
    "        if self.seq:\n",
    "            segmap = SegmentationMapsOnImage(mask, shape=mask.shape)\n",
    "            img,mask = seq(image=img, segmentation_maps=segmap)\n",
    "            # 直接获取数组内容\n",
    "            mask =  mask.get_arr()\n",
    "        \n",
    "        # 灰度图扩张维度成张量\n",
    "        return np.expand_dims(img,0) , np.expand_dims(mask,0)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "def6cdc4-2085-4f90-b10a-d3d8a95c4417",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentation pipeline: a random affine transform followed by an elastic deformation.\n",
    "seq = iaa.Sequential(\n",
    "    [\n",
    "        # Scale range (0.8, 1.2) and rotation range (-45, 45).\n",
    "        iaa.Affine(rotate=(-45, 45), scale=(0.8, 1.2)),\n",
    "        # Elastic deformation with the library's default parameters.\n",
    "        iaa.ElasticTransformation(),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fdcb2ab5-3a57-4ba6-90bf-a215379f2583",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap both splits in DataLoaders.\n",
    "batch_size = 12\n",
    "num_workers = 0\n",
    "\n",
    "# Only the training split receives the augmentation pipeline.\n",
    "train_dataset = SegmentDataset(where='train', seq=seq)\n",
    "test_dataset = SegmentDataset(where='test', seq=None)\n",
    "\n",
    "# Training batches are reshuffled every epoch; test batches keep a fixed order.\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    shuffle=True,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "test_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    shuffle=False,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e439ccb4-3e9c-4d0b-973b-1c4a32aef81b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "912306e4-354b-4d5e-83f5-adf4dc1c7e8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8fcbc5e1-c836-4388-b006-a798cc0aac9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义两次卷积操作\n",
    "class ConvBlock(torch.nn.Module):\n",
    "    def __init__(self,in_channels,out_channels):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.step = torch.nn.Sequential(\n",
    "            # 第一次卷积\n",
    "            torch.nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1),\n",
    "            # ReLU\n",
    "            torch.nn.ReLU(),\n",
    "            # 第二次卷积\n",
    "            torch.nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1),\n",
    "            # ReLU\n",
    "            torch.nn.ReLU()\n",
    "        )\n",
    "    \n",
    "    def forward(self,x):\n",
    "        \n",
    "        return self.step(x)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fa99a9f5-e460-441b-a58e-e9fe85b201db",
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNet(torch.nn.Module):\n",
    "    \"\"\"U-Net with three down/up-sampling stages for single-channel input.\n",
    "\n",
    "    ``forward`` returns raw logits with one output channel; no sigmoid is\n",
    "    applied. Pass the result through ``self.sigmoid`` / ``torch.sigmoid``\n",
    "    to obtain probabilities.\n",
    "\n",
    "    The input height and width must be divisible by 8: the encoder halves\n",
    "    the resolution three times and every decoder stage is concatenated with\n",
    "    the encoder feature map of the matching stage, so the sizes have to\n",
    "    line up again after each 2x upsampling.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Encoder (contracting path).\n",
    "        self.layer1 = ConvBlock(1, 64)\n",
    "        self.layer2 = ConvBlock(64, 128)\n",
    "        self.layer3 = ConvBlock(128, 256)\n",
    "        self.layer4 = ConvBlock(256, 512)\n",
    "\n",
    "        # Decoder (expanding path). Each block consumes the upsampled\n",
    "        # features concatenated with the skip connection, hence the sums.\n",
    "        self.layer5 = ConvBlock(256 + 512, 256)\n",
    "        self.layer6 = ConvBlock(128 + 256, 128)\n",
    "        self.layer7 = ConvBlock(64 + 128, 64)\n",
    "\n",
    "        # Final 1x1 convolution down to a single output channel.\n",
    "        self.layer8 = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, padding=0, stride=1)\n",
    "\n",
    "        # Parameter-free operations shared by all stages.\n",
    "        self.maxpool = torch.nn.MaxPool2d(kernel_size=2)\n",
    "        # align_corners=False is the value PyTorch already applies when the\n",
    "        # argument is omitted; spelling it out keeps the numerics unchanged\n",
    "        # and avoids the UserWarning about the default upsampling behaviour.\n",
    "        self.upsample = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n",
    "        # Not used by forward(); retained so code that references\n",
    "        # `model.sigmoid` keeps working.\n",
    "        self.sigmoid = torch.nn.Sigmoid()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logit map for ``x``.\n",
    "\n",
    "        The shape comments below are for a 1 x 256 x 256 input (per sample).\n",
    "        \"\"\"\n",
    "        # ---- Encoder ----\n",
    "        # 1 x 256 x 256 -> 64 x 256 x 256\n",
    "        x1 = self.layer1(x)\n",
    "        # -> 64 x 128 x 128\n",
    "        x1_p = self.maxpool(x1)\n",
    "\n",
    "        # -> 128 x 128 x 128\n",
    "        x2 = self.layer2(x1_p)\n",
    "        # -> 128 x 64 x 64\n",
    "        x2_p = self.maxpool(x2)\n",
    "\n",
    "        # -> 256 x 64 x 64\n",
    "        x3 = self.layer3(x2_p)\n",
    "        # -> 256 x 32 x 32\n",
    "        x3_p = self.maxpool(x3)\n",
    "\n",
    "        # Bottleneck: -> 512 x 32 x 32\n",
    "        x4 = self.layer4(x3_p)\n",
    "\n",
    "        # ---- Decoder ----\n",
    "        # Upsample to 512 x 64 x 64, concatenate with x3 -> 768 x 64 x 64\n",
    "        x5 = self.upsample(x4)\n",
    "        x5 = torch.cat([x5, x3], dim=1)\n",
    "        # -> 256 x 64 x 64\n",
    "        x5 = self.layer5(x5)\n",
    "\n",
    "        # Upsample to 256 x 128 x 128, concatenate with x2 -> 384 x 128 x 128\n",
    "        x6 = self.upsample(x5)\n",
    "        x6 = torch.cat([x6, x2], dim=1)\n",
    "        # -> 128 x 128 x 128\n",
    "        x6 = self.layer6(x6)\n",
    "\n",
    "        # Upsample to 128 x 256 x 256, concatenate with x1 -> 192 x 256 x 256\n",
    "        x7 = self.upsample(x6)\n",
    "        x7 = torch.cat([x7, x1], dim=1)\n",
    "        # -> 64 x 256 x 256\n",
    "        x7 = self.layer7(x7)\n",
    "\n",
    "        # Final 1x1 convolution: -> 1 x 256 x 256 logits (no sigmoid here).\n",
    "        x8 = self.layer8(x7)\n",
    "\n",
    "        return x8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "72a4af0b-b928-40d2-9963-71c2d15bd141",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试模型\n",
    "\n",
    "# 模型架构可视化\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3bf102ba-0650-4e78-8bff-a254d445a2c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# device\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0bf1d7b5-e93d-4040-bede-b9972f877856",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network, then move its parameters to the selected device.\n",
    "model = UNet()\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bc3edb2e-a32b-4967-8bf4-aa26d32e3aa4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 256, 256]             640\n",
      "              ReLU-2         [-1, 64, 256, 256]               0\n",
      "            Conv2d-3         [-1, 64, 256, 256]          36,928\n",
      "              ReLU-4         [-1, 64, 256, 256]               0\n",
      "        DoubleConv-5         [-1, 64, 256, 256]               0\n",
      "         MaxPool2d-6         [-1, 64, 128, 128]               0\n",
      "            Conv2d-7        [-1, 128, 128, 128]          73,856\n",
      "              ReLU-8        [-1, 128, 128, 128]               0\n",
      "            Conv2d-9        [-1, 128, 128, 128]         147,584\n",
      "             ReLU-10        [-1, 128, 128, 128]               0\n",
      "       DoubleConv-11        [-1, 128, 128, 128]               0\n",
      "        MaxPool2d-12          [-1, 128, 64, 64]               0\n",
      "           Conv2d-13          [-1, 256, 64, 64]         295,168\n",
      "             ReLU-14          [-1, 256, 64, 64]               0\n",
      "           Conv2d-15          [-1, 256, 64, 64]         590,080\n",
      "             ReLU-16          [-1, 256, 64, 64]               0\n",
      "       DoubleConv-17          [-1, 256, 64, 64]               0\n",
      "        MaxPool2d-18          [-1, 256, 32, 32]               0\n",
      "           Conv2d-19          [-1, 512, 32, 32]       1,180,160\n",
      "             ReLU-20          [-1, 512, 32, 32]               0\n",
      "           Conv2d-21          [-1, 512, 32, 32]       2,359,808\n",
      "             ReLU-22          [-1, 512, 32, 32]               0\n",
      "       DoubleConv-23          [-1, 512, 32, 32]               0\n",
      "           Conv2d-24          [-1, 256, 64, 64]       1,769,728\n",
      "             ReLU-25          [-1, 256, 64, 64]               0\n",
      "           Conv2d-26          [-1, 256, 64, 64]         590,080\n",
      "             ReLU-27          [-1, 256, 64, 64]               0\n",
      "       DoubleConv-28          [-1, 256, 64, 64]               0\n",
      "           Conv2d-29        [-1, 128, 128, 128]         442,496\n",
      "             ReLU-30        [-1, 128, 128, 128]               0\n",
      "           Conv2d-31        [-1, 128, 128, 128]         147,584\n",
      "             ReLU-32        [-1, 128, 128, 128]               0\n",
      "       DoubleConv-33        [-1, 128, 128, 128]               0\n",
      "           Conv2d-34         [-1, 64, 256, 256]         110,656\n",
      "             ReLU-35         [-1, 64, 256, 256]               0\n",
      "           Conv2d-36         [-1, 64, 256, 256]          36,928\n",
      "             ReLU-37         [-1, 64, 256, 256]               0\n",
      "       DoubleConv-38         [-1, 64, 256, 256]               0\n",
      "           Conv2d-39          [-1, 1, 256, 256]              65\n",
      "================================================================\n",
      "Total params: 7,781,761\n",
      "Trainable params: 7,781,761\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.25\n",
      "Forward/backward pass size (MB): 594.50\n",
      "Params size (MB): 29.69\n",
      "Estimated Total Size (MB): 624.44\n",
      "----------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\conda\\envs\\unet_env\\lib\\site-packages\\torch\\nn\\functional.py:3454: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Print a layer-by-layer summary for a single-channel 256 x 256 input.\n",
    "summary(model, input_size=(1, 256, 256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d720f1d4-ffe1-457b-ab5a-fab772718c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模拟输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "baec1de3-8375-4c35-bd9c-29f2ac7fd875",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: one forward pass on random data to check the output shape.\n",
    "random_input = torch.randn(1, 1, 256, 256).to(device)\n",
    "# Run without autograd: otherwise `output` keeps the graph and every\n",
    "# intermediate activation alive for as long as the variable exists.\n",
    "with torch.no_grad():\n",
    "    output = model(random_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8ab126d2-63e0-4028-8013-33b37d7a9664",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1, 256, 256])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "782e7632-f0f1-470f-9e36-aa0b0eab97dd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd29a25-15a6-42db-917a-bbb0d0f24742",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "666ec629-b023-4a5c-ac1f-df9d9e51e7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0cec0e3e-cf75-4c77-8740-a5e05a3d2f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义损失"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b92b4fa0-afdd-4f96-9537-b0c76d323821",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = torch.nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0bb17061-6326-4a98-bc40-2e2046851841",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "16b9e6e6-37e3-4d7f-87fa-496e56e25386",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all model parameters with a learning rate of 1e-4.\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "11256f91-6ffc-4546-8650-b63ccb26e413",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the learning rate when the monitored value stops decreasing.\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "scheduler = ReduceLROnPlateau(optimizer, mode='min')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "35509dfa-5105-472a-9e20-cd2c0cb6de13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "414cf9ea-e01e-4b0b-b9e7-035a47db550f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 计算测试集的loss\n",
    "def check_test_loss(loader,model):\n",
    "    loss = 0\n",
    "    # 不记录梯度\n",
    "    with torch.no_grad():\n",
    "        for i, (x, y) in enumerate(loader):\n",
    "            # 图片\n",
    "            x = x.to(device,dtype=torch.float32)\n",
    "            # 标签\n",
    "            y = y.to(device,dtype=torch.float32)\n",
    "            # 预测值\n",
    "            y_pred = model(x)\n",
    "            #计算损失\n",
    "            loss_batch = loss_fn(y_pred, y)\n",
    "            \n",
    "            loss += loss_batch\n",
    "    return loss / len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1e0cf0fd-dd53-4278-81d9-662766d38ed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用tensorboard记录参数\n",
    "from torch.utils.tensorboard import SummaryWriter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c07e2628-52eb-4547-974a-4a1e33ce2001",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard writer; event files are written under ./log.\n",
    "writer = SummaryWriter('./log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d06ae355-3537-431d-a7b0-681f24fddd20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "2d1e74b1-23fb-4df7-a5c4-16418b7ec4c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6476278305053711\n",
      "0.6448570489883423\n",
      "0.6427306532859802\n",
      "0.6413723826408386\n",
      "0.6385133266448975\n",
      "0.6362111568450928\n",
      "0.6347560882568359\n",
      "0.6316084861755371\n",
      "0.6278462409973145\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[1;32mIn [28]\u001b[0m, in \u001b[0;36m<cell line: 6>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     27\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m     29\u001b[0m \u001b[38;5;66;03m# 记录每个batch的train loss\u001b[39;00m\n\u001b[1;32m---> 30\u001b[0m loss_batch \u001b[38;5;241m=\u001b[39m \u001b[43mloss_batch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     31\u001b[0m \u001b[38;5;66;03m# 打印\u001b[39;00m\n\u001b[0;32m     32\u001b[0m \u001b[38;5;28mprint\u001b[39m(loss_batch\u001b[38;5;241m.\u001b[39mitem())\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Number of training epochs.\n",
    "EPOCH_NUM = 200\n",
    "# Lowest test loss seen so far; +inf so the first epoch always becomes the best.\n",
    "best_test_loss = float('inf')\n",
    "\n",
    "# torch.save does not create missing directories; make sure the target\n",
    "# exists before spending an epoch of training.\n",
    "os.makedirs('./save_model', exist_ok=True)\n",
    "\n",
    "for epoch in range(EPOCH_NUM):\n",
    "    start_time = time.time()\n",
    "    # Make sure the model is in training mode at the start of every epoch.\n",
    "    model.train()\n",
    "    loss = 0.0\n",
    "    for x, y in train_loader:\n",
    "        # Clear the gradients of the previous step before computing new ones.\n",
    "        optimizer.zero_grad()\n",
    "        # Images and labels, moved to the device as float32.\n",
    "        x = x.to(device, dtype=torch.float32)\n",
    "        y = y.to(device, dtype=torch.float32)\n",
    "        # Forward pass and loss.\n",
    "        y_pred = model(x)\n",
    "        loss_batch = loss_fn(y_pred, y)\n",
    "\n",
    "        # Backward pass and parameter update.\n",
    "        loss_batch.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Keep only the Python float so no tensor is carried across iterations.\n",
    "        loss_batch = loss_batch.item()\n",
    "        print(loss_batch)\n",
    "        loss += loss_batch\n",
    "\n",
    "    # Mean of the per-batch training losses for this epoch.\n",
    "    loss = loss / len(train_loader)\n",
    "    # ReduceLROnPlateau monitors the epoch's training loss and lowers the\n",
    "    # learning rate once it has stopped improving for `patience` epochs.\n",
    "    scheduler.step(loss)\n",
    "\n",
    "    # Loss on the test split, as a plain float for logging and comparison.\n",
    "    test_loss = float(check_test_loss(test_loader, model))\n",
    "\n",
    "    # Log both curves to TensorBoard and flush so an interrupted run keeps them.\n",
    "    writer.add_scalar('Loss/train', loss, epoch)\n",
    "    writer.add_scalar('Loss/test', test_loss, epoch)\n",
    "    writer.flush()\n",
    "\n",
    "    # Track the best test loss and checkpoint the corresponding weights.\n",
    "    if best_test_loss > test_loss:\n",
    "        best_test_loss = test_loss\n",
    "        torch.save(model.state_dict(), './save_model/unet_best.pt')\n",
    "        print('第{}个EPOCH达到最低的测试loss:{}'.format(epoch,best_test_loss))\n",
    "\n",
    "    # Per-epoch report: duration, train loss, test loss.\n",
    "    print('第{}个epoch执行时间：{}s，train loss为：{}，test loss为：{}'.format(\n",
    "        epoch,\n",
    "        time.time()-start_time,\n",
    "        loss,\n",
    "        test_loss\n",
    "    ) )\n",
    "    # Always keep the most recent weights as well.\n",
    "    torch.save(model.state_dict(), './save_model/unet_latest.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772c8a46-2b0f-4f92-8607-993103c724f2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e87b14d9-28ae-4f53-b17c-348e304a732c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
